
import pandas as pd
from _funcs import parallel_read_csv, get_chunks, read_csv_chunk

# Root data directory. Every input/output path below is built as `dir + <relative path>`,
# so the trailing slash is required.
# NOTE(review): this name shadows the built-in `dir()`; renaming it means touching every function in this file.
dir = r'/Volumes/Zihao_SSD2/PatentsView/'

## ==================================================================================================================
## Robustness checks wrt choice of k
## Zihao Li. 06/2024

## Inputs:  temp/actual_citation_lst.csv
##          cleandata/sim_score_1981_2015_top10.csv
##          cleandata/sim_score_1981_2015_top50.csv
##          cleandata/sim_score_1981_2015_top5_kpss.csv

## Outputs: temp/omission_panel{k}_robust.csv for k=1-10
##          temp/omission_panel_flex.csv
##          temp/omission_panel_restrict.csv
## ==================================================================================================================


def fixed_k(actual_citation_file):
    """Export one omission panel per fixed relevant-list length k = 1..10.

    Candidate (artificial) citations are ranked within each citing patent by
    descending `sim_score` (ties broken by descending `cited_patent_id`). For
    each k, the top-k candidates are merged with the actual citation lists and
    flagged with `omission` = 0 if the candidate appears in the patent's
    actual citation list, 1 otherwise.

    Args:
        actual_citation_file: Path, relative to the data root `dir`, of the
            actual-citation CSV. It must provide at least the columns
            `patent_id` and `actual_citation_list`.

    Returns:
        None. Writes `temp/omission_panel{k}_robust.csv` for k = 1..10.
    """
    ## Load actual citation data (generated in gen_omit_panel.py)
    print('Loading actual citation data...')
    df_actual_lst = parallel_read_csv(dir + actual_citation_file)

    ## Load artificial citation data and rank candidates within each citing patent
    print('Loading artificial citation data...')
    df_artificial = pd.read_csv(dir + 'cleandata/sim_score_1981_2015_top10.csv', low_memory=False)
    df_artificial = df_artificial.drop(columns=['patent_idx', 'cited_patent_idx', 'patent_year', 'cited_patent_year'])
    print('    Shape of artificial_citation dataset is {}.'.format(df_artificial.shape)) # (49532510, 3)
    df_artificial = df_artificial.sort_values(by=['patent_id', 'sim_score', 'cited_patent_id'], ascending=[True, False, False])
    # After the sort, the running count within each patent_id is the similarity rank (1 = most similar)
    df_artificial['rank'] = df_artificial.groupby('patent_id').cumcount() + 1

    ## Generate omission panel for different k (i.e. length of relevant list)
    for k in range(1, 11):
        print(f'Processing for k={k}...')
        df = df_artificial[df_artificial['rank'] <= k]
        # Report the k-filtered frame, not the full table
        print('    Shape of artificial_citation dataset is {}.'.format(df.shape))

        # Merging actual_citation_list with artificial citation data
        print('    Merging actual_citation_list with artificial citation data...')
        df = df.merge(df_actual_lst, how='left', on='patent_id')
        print('    Shape of merged dataset is {}.'.format(df.shape))

        # Generating omission index
        print('    Generating omission index...')
        # Patents with no actual citation list cannot be scored
        df = df.dropna(subset=['actual_citation_list'])
        # Iterate the two needed columns directly instead of a row-wise apply:
        # same result per row, far less overhead, and it also works when no rows remain.
        # NOTE(review): if actual_citation_list is a string, `in` is a substring test, so an id
        # contained in a longer id would count as cited -- confirm the list format.
        df['omission'] = [
            0 if str(cited_id) in citation_list else 1
            for cited_id, citation_list in zip(df['cited_patent_id'], df['actual_citation_list'])
        ]
        print('    Shape of omission dataset is {}.'.format(df.shape))

        # Exporting omission panel
        print('    Exporting omission panel...')
        df.to_csv(dir + f'temp/omission_panel{k}_robust.csv', index=False)
        # Release this k's panel before building the next one
        del df


def flexible_k(actual_citation_file, max_citations=32):
    """Export an omission panel whose list length equals each patent's own citation count.

    Candidate (artificial) citations are ranked within each citing patent by
    descending `sim_score` (ties broken by descending `cited_patent_id`). For
    every patent with at most `max_citations` actual citations, the top
    `num_citations` candidates are kept and flagged with `omission` = 0 if the
    candidate appears in the patent's actual citation list, 1 otherwise.

    Args:
        actual_citation_file: Path, relative to the data root `dir`, of the
            actual-citation CSV. It must provide at least the columns
            `patent_id`, `actual_citation_list` and `num_citations`.
        max_citations: Largest `num_citations` a patent may have to stay in
            the panel. The default of 32 is, per the original author's note,
            the 90th percentile of the number of actual citations.

    Returns:
        None. Writes `temp/omission_panel_flex.csv`.
    """
    ## Load actual citation data
    print('Loading actual citation data...')
    df_actual_lst = parallel_read_csv(dir + actual_citation_file)

    print('Loading artificial citation data...')
    df_artificial = pd.read_csv(dir + 'cleandata/sim_score_1981_2015_top50.csv', low_memory=False)
    df_artificial = df_artificial.drop(columns=['patent_idx', 'cited_patent_idx', 'patent_year', 'cited_patent_year'])
    df_artificial = df_artificial.sort_values(by=['patent_id', 'sim_score', 'cited_patent_id'], ascending=[True, False, False])
    # After the sort, the running count within each patent_id is the similarity rank (1 = most similar)
    df_artificial['rank'] = df_artificial.groupby('patent_id').cumcount() + 1

    print('Merging actual_citation_list with artificial citation data...')
    df = df_artificial.merge(df_actual_lst, how='left', on='patent_id')
    # The inputs are no longer needed once merged
    del df_artificial, df_actual_lst
    # Rows without a match have NaN num_citations and fail both comparisons, so they drop out here
    df = df[df['num_citations'] <= max_citations]
    # Keep as many top-ranked candidates as the patent has actual citations
    df = df[df['rank'] <= df['num_citations']]

    print('Generating omission index...')
    df = df.dropna(subset=['actual_citation_list'])
    # Iterate the two needed columns directly instead of a row-wise apply:
    # same result per row, far less overhead, and it also works when no rows remain.
    # NOTE(review): if actual_citation_list is a string, `in` is a substring test, so an id
    # contained in a longer id would count as cited -- confirm the list format.
    df['omission'] = [
        0 if str(cited_id) in citation_list else 1
        for cited_id, citation_list in zip(df['cited_patent_id'], df['actual_citation_list'])
    ]
    print('Shape of omission dataset is {}.'.format(df.shape)) # (38561798, 7)

    ## Export omission panel
    print('Exporting omission panel...')
    df.to_csv(dir + 'temp/omission_panel_flex.csv', index=False)


def restrict_k(actual_citation_file):
    """Export the omission panel for the KPSS-firms-only candidate file.

    Every candidate (artificial) citation in the top-5 KPSS file is merged with
    the actual citation lists and flagged with `omission` = 0 if the candidate
    appears in the patent's actual citation list, 1 otherwise.

    Args:
        actual_citation_file: Path, relative to the data root `dir`, of the
            actual-citation CSV. It must provide at least the columns
            `patent_id` and `actual_citation_list`.

    Returns:
        None. Writes `temp/omission_panel_restrict.csv`.
    """
    ## Load actual citation data
    print('Loading actual citation data...')
    df_actual_lst = parallel_read_csv(dir + actual_citation_file)

    print('Loading artificial citation data...')
    # low_memory=False matches the other loaders in this file: column dtypes are inferred
    # from the whole file rather than chunk by chunk.
    df_artificial = pd.read_csv(dir + 'cleandata/sim_score_1981_2015_top5_kpss.csv', low_memory=False) # KPSS firms only
    df_artificial = df_artificial.drop(columns=['patent_idx', 'cited_patent_idx', 'patent_year', 'cited_patent_year'])
    # The sort only fixes the row order of the exported panel; no rank column is needed here
    df_artificial = df_artificial.sort_values(by=['patent_id', 'sim_score', 'cited_patent_id'], ascending=[True, False, False])

    print('Merging actual_citation_list with artificial citation data...')
    df = df_artificial.merge(df_actual_lst, how='left', on='patent_id')
    # The inputs are no longer needed once merged
    del df_artificial, df_actual_lst
    print('Shape of merged dataset is {}.'.format(df.shape)) # (9065575, 5)

    print('Generating omission index...')
    # Patents with no actual citation list cannot be scored
    df = df.dropna(subset=['actual_citation_list'])
    # Iterate the two needed columns directly instead of a row-wise apply:
    # same result per row, far less overhead, and it also works when no rows remain.
    # NOTE(review): if actual_citation_list is a string, `in` is a substring test, so an id
    # contained in a longer id would count as cited -- confirm the list format.
    df['omission'] = [
        0 if str(cited_id) in citation_list else 1
        for cited_id, citation_list in zip(df['cited_patent_id'], df['actual_citation_list'])
    ]
    print('Shape of omission dataset is {}.'.format(df.shape)) # (8737095, 6)

    ## Export omission panel
    print('Exporting omission panel...')
    df.to_csv(dir + 'temp/omission_panel_restrict.csv', index=False)


def main():
    """Build all three robustness panels from the same actual-citation file."""
    citation_path = 'temp/actual_citation_lst.csv'
    # Same order as before: fixed k, then flexible k, then the KPSS-restricted panel
    for build_panel in (fixed_k, flexible_k, restrict_k):
        build_panel(actual_citation_file=citation_path)

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()